{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library\n",
    "%pip install pythae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# run on the GPU when CUDA is available, otherwise fall back to the CPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _as_image_batch(raw_images):\n",
    "    \"\"\"Reshape raw images to (N, 1, 28, 28) and divide the values by 255.\"\"\"\n",
    "    return raw_images.reshape(-1, 1, 28, 28) / 255.\n",
    "\n",
    "\n",
    "mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)\n",
    "\n",
    "# the last 10000 images are held out for evaluation, the rest is used for training\n",
    "train_dataset = _as_image_batch(mnist_trainset.data[:-10000])\n",
    "eval_dataset = _as_image_batch(mnist_trainset.data[-10000:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import AE, AEConfig\n",
    "from pythae.trainers import BaseTrainerConfig\n",
    "from pythae.pipelines.training import TrainingPipeline\n",
    "from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_AE_MNIST, Decoder_ResNet_AE_MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training settings; checkpoints and the final model are written under `output_dir`\n",
    "config = BaseTrainerConfig(\n",
    "    output_dir='my_model',\n",
    "    learning_rate=1e-4,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    num_epochs=10,  # Change this to train the model a bit more\n",
    ")\n",
    "\n",
    "# autoencoder on 1x28x28 inputs with a 16-dimensional latent space\n",
    "model_config = AEConfig(input_dim=(1, 28, 28), latent_dim=16)\n",
    "\n",
    "# encoder and decoder are built from the same model configuration\n",
    "encoder = Encoder_ResNet_AE_MNIST(model_config)\n",
    "decoder = Decoder_ResNet_AE_MNIST(model_config)\n",
    "\n",
    "model = AE(model_config=model_config, encoder=encoder, decoder=decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bundle the training configuration and the model into a training pipeline\n",
    "pipeline = TrainingPipeline(training_config=config, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# launch the pipeline on the training split, with the held-out split as evaluation data\n",
    "pipeline(train_data=train_dataset, eval_data=eval_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pythae.models import AutoModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick the last entry of 'my_model' in sorted order\n",
    "# (assumed to be the most recent training run -- TODO confirm the folder names sort by time)\n",
    "last_training = sorted(os.listdir('my_model'))[-1]\n",
    "\n",
    "# move the reloaded model to `device` explicitly: the reconstruction and interpolation\n",
    "# cells feed it tensors created with `.to(device)`, so model and inputs must share a device\n",
    "trained_model = AutoModel.load_from_folder(\n",
    "    os.path.join('my_model', last_training, 'final_model')\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.samplers import NormalSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create normal sampler\n",
    "normal_samper = NormalSampler(\n",
    "    model=trained_model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample\n",
    "gen_data = normal_samper.sample(\n",
    "    num_samples=25\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show results with normal sampler on a 5x5 grid (row-major order)\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "for idx, ax in enumerate(axes.flat):\n",
    "    sample = gen_data[idx].cpu().squeeze(0)  # drop the leading channel axis for imshow\n",
    "    ax.imshow(sample, cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GMM sampler configured with 10 mixture components\n",
    "gmm_sampler_config = GaussianMixtureSamplerConfig(n_components=10)\n",
    "gmm_sampler = GaussianMixtureSampler(sampler_config=gmm_sampler_config, model=trained_model)\n",
    "\n",
    "# the sampler is fitted on the training images before any sample is drawn\n",
    "gmm_sampler.fit(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw 25 samples with the GMM sampler\n",
    "gen_data = gmm_sampler.sample(num_samples=25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show results with gmm sampler on a 5x5 grid (row-major order)\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "for idx, ax in enumerate(axes.ravel()):\n",
    "    sample = gen_data[idx].cpu().squeeze(0)  # drop the leading channel axis for imshow\n",
    "    ax.imshow(sample, cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ... the other samplers work the same way"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing reconstructions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reconstruct the first 25 evaluation images, then bring the result back to the CPU\n",
    "eval_batch = eval_dataset[:25].to(device)\n",
    "reconstructions = trained_model.reconstruct(eval_batch).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show reconstructions on a 5x5 grid (row-major order)\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "for idx, ax in enumerate(axes.flat):\n",
    "    recon = reconstructions[idx].cpu().squeeze(0)  # drop the leading channel axis for imshow\n",
    "    ax.imshow(recon, cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show the true data: the same first 25 evaluation images, on a 5x5 grid\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "for idx, ax in enumerate(axes.flat):\n",
    "    original = eval_dataset[idx].cpu().squeeze(0)  # drop the leading channel axis for imshow\n",
    "    ax.imshow(original, cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing interpolations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# interpolate between evaluation images 0-4 (start) and 5-9 (end) with granularity=10\n",
    "start_images = eval_dataset[:5].to(device)\n",
    "end_images = eval_dataset[5:10].to(device)\n",
    "interpolations = trained_model.interpolate(start_images, end_images, granularity=10).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show interpolations: axes row i, column j displays interpolations[i, j]\n",
    "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n",
    "\n",
    "for i, axes_row in enumerate(axes):\n",
    "    for j, ax in enumerate(axes_row):\n",
    "        frame = interpolations[i, j].cpu().squeeze(0)  # drop the leading channel axis for imshow\n",
    "        ax.imshow(frame, cmap='gray')\n",
    "        ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "95022f601a219c6b6d093149c9a9b9a061a4446d3680d89cef8a1f82970031f2"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
